import torch
import torch.nn as nn
import torchvision.models as models


class Encoder(nn.Module):
    """CNN image encoder: ResNet-152 backbone projected to a caption embedding."""

    def __init__(self, embedding_size):
        """
        Args:
            embedding_size: dimensionality of the output image embedding.
        """
        super().__init__()

        backbone = models.resnet152(pretrained=True)

        # Drop the final fully connected classifier; keep the convolutional
        # trunk up to (and including) global average pooling.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        # Project the pooled CNN features into the caption embedding space.
        self.linear = nn.Linear(
            in_features=backbone.fc.in_features, out_features=embedding_size)
        self.BatchNorm = nn.BatchNorm1d(
            num_features=embedding_size, momentum=0.01)

    def forward(self, images):
        """Return a (batch, embedding_size) feature tensor for a batch of images."""
        pooled = self.resnet(images)
        flat = pooled.view(pooled.size(0), -1)
        return self.BatchNorm(self.linear(flat))

class MyLSTMCell(nn.Module):
    """LSTM cell that optionally attends over image features before each step."""

    def __init__(self, embedding_size, hidden_size, device):
        """
        Args:
            embedding_size: size of the concatenated (context, word) step input;
                the attention module consumes half of it (the word embedding).
            hidden_size: LSTM hidden state size.
            device: unused here; kept for interface compatibility with callers.
        """
        super().__init__()
        self.lstm_cell = nn.LSTMCell(embedding_size, hidden_size, bias=True)
        self.attention = Attention(embedding_size // 2, hidden_size, 0)

    def forward(self, state, current_embedding, image_feat=None, attention_flag=True):
        """Run one decoding step.

        Args:
            state: (h, c) tuple of (batch, hidden_size) tensors.
            current_embedding: (batch, embedding_size // 2) word embedding.
            image_feat: image features; attended over when attention_flag is
                True, otherwise concatenated directly after squeeze(0).
            attention_flag: choose attention context vs. raw image features.

        Returns:
            (h_out, c_out): the next LSTM state.
        """
        h, c = state
        if attention_flag:
            # Attend over image features; the attention weights are not needed here.
            att_feat, _ = self.attention(current_embedding, h, image_feat)
            input_feat = torch.cat((att_feat, current_embedding), dim=1)  # combine image and text input
        else:
            input_feat = torch.cat((image_feat.squeeze(0), current_embedding), dim=1)
        h_out, c_out = self.lstm_cell(input_feat, (h, c))
        return h_out, c_out


class Attention(nn.Module):
    """Single-head (unscaled) dot-product attention with a learned output projection."""

    def __init__(self, embedding_size, hidden_size, dropout, out_features=256):
        """
        Args:
            embedding_size: input size of query and value vectors.
            hidden_size: projection size used for Q/K/V and the energy matrix.
            dropout: dropout probability applied to the attention weights.
            out_features: size of the attended output vector (default 256,
                matching the original hard-coded value).
        """
        super().__init__()
        self.hidden_size = hidden_size

        self.fc_q = nn.Linear(embedding_size, hidden_size)
        self.fc_k = nn.Linear(hidden_size, hidden_size)
        self.fc_v = nn.Linear(embedding_size, hidden_size)

        self.fc_o = nn.Linear(hidden_size, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Compute attended features.

        Args:
            query: (batch, embedding_size) tensor.
            key: (batch, hidden_size) tensor.
            value: (batch, embedding_size) tensor.

        Returns:
            (attended, attention): attended is (batch, out_features) after the
            singleton sequence dim is squeezed; attention holds the softmax
            weights over the energy matrix.
        """
        batch_size = query.size(0)

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Reshape to (batch, hidden_size, seq) before forming the energy matrix.
        Q = Q.view(batch_size, -1, self.hidden_size).permute(0, 2, 1)
        K = K.view(batch_size, -1, self.hidden_size).permute(0, 2, 1)
        V = V.view(batch_size, -1, self.hidden_size).permute(0, 2, 1)
        # Unscaled dot-product attention weights.
        energy = torch.matmul(Q, K.permute(0, 2, 1))
        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V)
        x = x.permute(0, 2, 1)
        x = self.fc_o(x)
        return x.squeeze(1), attention

# https://github.com/njchoma/transformer_image_caption/blob/master/src/models/simple_model.py
# https://github.com/njchoma/transformer_image_caption/blob/master/src/models/simple_model.py
class Decoder(nn.Module):
    """LSTM caption decoder with optional attention over encoder image features."""

    def __init__(self, embed_size, hidden_size, vocab_size, device):
        """
        Args:
            embed_size: word/image embedding size.
            hidden_size: LSTM hidden state size.
            vocab_size: number of output token classes.
            device: torch device used for internally created tensors.
        """
        super().__init__()
        self.device = device
        self.embed_size = embed_size
        self.embed = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size)
        # Each step consumes [context ; word embedding], hence 2 * embed_size.
        self.lstm = MyLSTMCell(embedding_size=embed_size * 2, hidden_size=hidden_size, device=device)
        self.linear = nn.Linear(in_features=hidden_size,
                                out_features=vocab_size)
        self.hidden_size = hidden_size

    def forward(self, features, captions, length, attention_flag=True):
        """Teacher-forced decoding over a batch of captions.

        Args:
            features: (batch, feature_size) encoder output.
            captions: list of (1, seq_len) token-id tensors, one per image.
            length: unused; kept for interface compatibility.
            attention_flag: passed through to the LSTM cell.

        Returns:
            (total_steps, vocab_size) logits over all non-final time steps.
        """
        embeddings = [self.embed(caption) for caption in captions]
        features = features.unsqueeze(1)  # batch size * 1 * feature size
        for i in range(features.size(0)):
            # Prepend a zero "start" embedding so the first step sees only the image.
            # (was hard-coded 256 and .cuda(); use embed_size and self.device)
            start = torch.zeros((1, 1, self.embed_size), device=self.device)
            embeddings[i] = torch.cat((start, embeddings[i]), 1).contiguous()
        # Because the feature map is placed first, every subsequent token is
        # conditioned on the image features.
        hidden_states = []
        for j, embedding in enumerate(embeddings):
            h = torch.zeros((1, self.hidden_size), device=self.device)
            c = torch.zeros((1, self.hidden_size), device=self.device)
            for i in range(embedding.size(1)):
                h, c = self.lstm((h, c), embedding[:, i], features[j], attention_flag=attention_flag)
                # The state after the final token has no next-token target, so skip it.
                if i == embedding.size(1) - 1:
                    continue
                hidden_states.append(h)
        hidden = torch.cat(hidden_states, 0)

        output = self.linear(hidden.squeeze(0).contiguous())  # (batch x tokens_length) * labels_count
        return output

    def sample(self, features, states=None, longest_sentence_length=100, attention_flag=True):
        """Greedy decoding: repeatedly feed back the argmax token embedding.

        Args:
            features: (batch, feature_size) encoder output.
            states: unused; kept for interface compatibility.
            longest_sentence_length: number of greedy steps to run.
            attention_flag: passed through to the LSTM cell.

        Returns:
            1-D tensor of sampled token ids.
        """
        sampled_ids = []
        image_feat = features.unsqueeze(0)  # batch size * 256
        h = torch.zeros((1, self.hidden_size), device=self.device)
        c = torch.zeros((1, self.hidden_size), device=self.device)
        # Zero embedding stands in for the start token (was hard-coded 256 / .cuda()).
        embedding = torch.zeros((1, self.embed_size), device=self.device)
        for _ in range(longest_sentence_length):
            h, c = self.lstm((h, c), embedding, image_feat, attention_flag=attention_flag)
            output = self.linear(h.squeeze(1))
            predicted = output.max(dim=1, keepdim=True)[1]
            sampled_ids.append(predicted)
            embedding = self.embed(predicted).squeeze(1)

        sampled_ids = torch.cat(sampled_ids, 1)

        return sampled_ids.squeeze()


if __name__ == "__main__":
    # Smoke test: run one teacher-forced forward pass on a random image feature.
    device = "cuda"
    decoder = Decoder(256, 512, 19, device)
    decoder.to(device)
    captions = []
    # Decoder.forward expects each caption as a (1, seq_len) tensor: it
    # concatenates a 3-D start embedding and indexes embedding[:, i], so a
    # 1-D token tensor would fail with a torch.cat rank mismatch.
    caption = torch.tensor([0,  4,  5,  6,  7,  8,  7,  8,  7,  8,  9, 10,  5, 11,  5, 12,  7, 13,
          7, 14,  9,  9, 10,  5, 15,  5, 12,  7, 13,  7, 14,  9, 15,  5, 12,  7,
         13,  7, 14,  9,  9, 10,  5, 15,  5, 12,  7, 13,  7, 16,  9, 15,  5, 12,
          7, 13,  7, 14,  9,  9,  1]).unsqueeze(0).to(device)
    captions.append(caption)
    image_feat = torch.rand(1,  256).to(device)
    output = decoder(image_feat, captions, None)
    print(output.size())

